{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "cellType": "markdown",
    "uuid": "e5721b04-150d-4004-9c91-d6411076d08c"
   },
   "source": [
    "## 一个简单Word2Vec的例子\n",
    "\n",
    "这个例子展示了如何利用DSW进行构造一个Word2Vec的学习，Word2Vec是NLP训练一个比较基础的数据处理方式，具体原理大家可以仔细阅读论文《[Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781)》\n",
    "\n",
    "我们准备一个文章text，然后通过这个文章的信息来学习出文章中单词的一个向量表达，使得单词的空间距离能够体现单词语义距离，使得越靠近的单词语义越相似。\n",
    "\n",
    "首先我们把文章读入到words这个数组中。可以看到这个文章总共包括17005207单词。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "uuid": "b83c4f46-d254-462f-98f0-bdc16910cb7d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "words size 17005207\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import zipfile\n",
    "\n",
    "with zipfile.ZipFile(\"text.zip\") as f:\n",
    "    words = tf.compat.as_str(f.read(f.namelist()[0])).split()\n",
    "    \n",
    "print('words size', len(words))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cellType": "markdown",
    "uuid": "8bb2ab55-8755-434a-b25b-fedb84f57872"
   },
   "source": [
    "然后我们准备一个词典，例子中我们限制一下这个词典最大size为50000, 字典中我们保存文章最多出现49999个词，然后其他词都当做'UNK'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "uuid": "12bdc2d4-d682-42bf-8911-57a4e735a0ab"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "最多5个单词以及出现次数 [('the', 1061396), ('of', 593677), ('and', 416629), ('one', 411764), ('in', 372201)]\n"
     ]
    }
   ],
   "source": [
    "import collections\n",
    "import math\n",
    "\n",
    "vocabulary_size = 50000\n",
    "count = [['UNK', -1]]\n",
    "count.extend(collections.Counter(words).most_common(vocabulary_size - 1))\n",
    "print(\"最多5个单词以及出现次数\", count[1:6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "uuid": "8e95f430-84e0-4809-ab19-4627747ac0a8"
   },
   "outputs": [],
   "source": [
    "为了后面训练的方便，我们把单词用字典的index来进行标识，并且把原文用这个进行编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "uuid": "f160185a-374f-4fea-878f-8606fc03f95c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "编码后文章为 [5234, 3081, 12, 6, 195, 2, 3134, 46, 59, 156]\n"
     ]
    }
   ],
   "source": [
    "dictionary = dict()\n",
    "for word, _ in count:\n",
    "    dictionary[word] = len(dictionary)\n",
    "data = list()\n",
    "unk_count = 0\n",
    "for word in words:\n",
    "    if word in dictionary:\n",
    "        index = dictionary[word]\n",
    "    else:\n",
    "        index = 0  # dictionary['UNK']\n",
    "    unk_count += 1\n",
    "    data.append(index)\n",
    "count[0][1] = unk_count\n",
    "\n",
    "print('编码后文章为', data[:10], '...')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cellType": "markdown",
    "uuid": "6b849cc6-ffde-4bfc-8fe2-1fb13c3f1cd7"
   },
   "source": [
    "建立一个方向查找表，等学习完，可以把单词的编码又变回到原单词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "uuid": "38459486-a7a0-435e-8b04-ece276b05c63"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against']\n"
     ]
    }
   ],
   "source": [
    "reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "print([reverse_dictionary[i] for i in data[:10]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cellType": "markdown",
    "uuid": "120248f0-cbac-412e-9300-aca639a7f3ce"
   },
   "source": [
    "接下来我们进入正题，我们构造训练word2vec的样本，也就是skip gram论文中的方法，把这个定义为一个函数，留到后面训练时候用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "uuid": "168efbb7-dc1b-4479-890d-fa28fbd52717"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "data_index = 0\n",
    "def generate_batch(batch_size, num_skips, skip_window):\n",
    "    global data_index\n",
    "    batch = np.ndarray(shape=(batch_size), dtype=np.int32)\n",
    "    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)\n",
    "    span = 2 * skip_window + 1 # [ skip_window target skip_window ]\n",
    "    buffer = collections.deque(maxlen=span)\n",
    "    for _ in range(span):\n",
    "        buffer.append(data[data_index])\n",
    "        data_index = (data_index + 1) % len(data)\n",
    "    for i in range(batch_size // num_skips):\n",
    "        target = skip_window  # target label at the center of the buffer\n",
    "        targets_to_avoid = [ skip_window ]\n",
    "        for j in range(num_skips):\n",
    "            while target in targets_to_avoid:\n",
    "                target = random.randint(0, span - 1)\n",
    "            targets_to_avoid.append(target)\n",
    "            batch[i * num_skips + j] = buffer[skip_window]\n",
    "            labels[i * num_skips + j, 0] = buffer[target]\n",
    "        buffer.append(data[data_index])\n",
    "        data_index = (data_index + 1) % len(data)\n",
    "    return batch, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cellType": "markdown",
    "uuid": "b26ab915-1447-405e-93cd-959ee2a1bb9b"
   },
   "source": [
    "我们观察这个训练样本的样子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "uuid": "44c95b6c-9886-48f4-b4f2-70c1d4115bd6"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3081 originated -> 5234 anarchism\n",
      "3081 originated -> 12 as\n",
      "12 as -> 6 a\n",
      "12 as -> 3081 originated\n",
      "6 a -> 195 term\n",
      "6 a -> 12 as\n",
      "195 term -> 6 a\n",
      "195 term -> 2 of\n"
     ]
    }
   ],
   "source": [
    "batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)\n",
    "for i in range(8):\n",
    "    print(batch[i], reverse_dictionary[batch[i]],\n",
    "        '->', labels[i, 0], reverse_dictionary[labels[i, 0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "uuid": "810eb8a3-cb7e-4c9f-86d0-4385a66c7da5"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
